# -*- coding: utf-8 -*-
# @Author: Wenwen Yu
# @Created Time: 10/4/2020 14:18

from copy import deepcopy

import numpy as np
import torch
from torch import nn

from .backbone import ConvEmbeddingGC
from .transformer import MultiHeadAttention, PositionwiseFeedForward, PositionalEncoding, EncoderLayer, Encoder, \
    DecoderLayer, Decoder, Embeddings
from utils.label_util import LabelTransformer


class MASTER(nn.Module):
    """
    Encoder-decoder MASTER model.

    The model is assembled from a ``ConvEmbeddingGC`` backbone, an optional transformer
    encoder, a transformer decoder and a ``Generator`` that projects the decoder output
    to ``LabelTransformer.nclass`` scores.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: must provide the four config dicts ``common_kwargs``, ``backbone_kwargs``,
            ``encoder_kwargs`` and ``decoder_kwargs``
        :raises KeyError: if one of the config dicts (or an entry read from them) is missing
        """

        super(MASTER, self).__init__()

        common_kwargs = kwargs['common_kwargs']
        backbone_kwargs = kwargs['backbone_kwargs']
        encoder_kwargs = kwargs['encoder_kwargs']
        decoder_kwargs = kwargs['decoder_kwargs']

        # with encoder: cnn(+gc block) + transformer encoder + transformer decoder
        # without encoder: cnn(+gc block) + transformer decoder
        self.with_encoder = common_kwargs['with_encoder']

        self.build_model(common_kwargs, backbone_kwargs, encoder_kwargs, decoder_kwargs)

        # Xavier-uniform init for every parameter with more than one dimension;
        # 1-D parameters keep whatever their layer's constructor gave them.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def build_model(self, common_kwargs, backbone_kwargs, encoder_kwargs, decoder_kwargs):
        """
        Create all sub-modules and attach them to ``self``.

        Note: ``common_kwargs['n_class']`` is overwritten in place with ``LabelTransformer.nclass``,
        and the ``encoder_kwargs`` entries are read even when ``self.with_encoder`` is False.

        :param common_kwargs: reads ``model_size`` and ``multiheads``; ``n_class`` is written
        :param backbone_kwargs: forwarded as keyword arguments to ``ConvEmbeddingGC``
        :param encoder_kwargs: reads ``stacks``, ``dropout`` and ``feed_forward_size``
        :param decoder_kwargs: reads ``stacks``, ``dropout`` and ``feed_forward_size``
        :return: None
        """
        common_kwargs['n_class'] = LabelTransformer.nclass
        tgt_vocab = common_kwargs['n_class']

        # shared
        d_model = common_kwargs['model_size']
        h = common_kwargs['multiheads']

        # encoder cfg
        encoders = encoder_kwargs['stacks']
        encoder_dropout = encoder_kwargs['dropout']
        encoder_d_ff = encoder_kwargs['feed_forward_size']
        # decoder cfg
        decoders = decoder_kwargs['stacks']
        decoder_dropout = decoder_kwargs['dropout']
        decoder_d_ff = decoder_kwargs['feed_forward_size']

        # Prototype layers; every consumer below receives its own deepcopy.
        # The construction order is kept as is so that seeded initialisation stays reproducible.
        encoder_attn = MultiHeadAttention(h, d_model, encoder_dropout)
        encoder_ff = PositionwiseFeedForward(d_model, encoder_d_ff, encoder_dropout)
        encoder_position = PositionalEncoding(d_model, encoder_dropout)

        decoder_attn = MultiHeadAttention(h, d_model, decoder_dropout)
        decoder_ff = PositionwiseFeedForward(d_model, decoder_d_ff, decoder_dropout)
        decoder_position = PositionalEncoding(d_model, decoder_dropout)

        conv_embedding_gc = ConvEmbeddingGC(**backbone_kwargs)

        if self.with_encoder:
            encoder = Encoder(EncoderLayer(d_model, self_attn=deepcopy(encoder_attn), feed_forward=deepcopy(encoder_ff),
                                           dropout=encoder_dropout), encoders)
        else:
            encoder = None
        decoder = Decoder(DecoderLayer(d_model, self_attn=deepcopy(decoder_attn), src_attn=deepcopy(decoder_attn),
                                       feed_forward=deepcopy(decoder_ff),
                                       dropout=decoder_dropout), decoders)
        # The source embedding uses the encoder's positional encoding even without an encoder stack.
        src_embed = nn.Sequential(conv_embedding_gc, deepcopy(encoder_position))
        tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), deepcopy(decoder_position))
        generator = Generator(d_model, tgt_vocab)
        padding = LabelTransformer.PAD

        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
        self.padding = padding

    def make_mask(self, src, tgt):
        """
        Build the attention masks for one batch.

        :param src: [b, c, h, len_src]; not read, kept for interface compatibility
        :param tgt: [b, l_tgt]
        :return: ``(None, tgt_mask)``; ``tgt_mask`` is a bool tensor of shape [b, 1, l_tgt, l_tgt]
        """

        # src_mask is not needed, since the embedding generated by ConvNet is dense.
        # True where the target entry is not the padding index.
        trg_pad_mask = (tgt != self.padding).unsqueeze(1).unsqueeze(3)  # (b, 1, l_tgt, 1)

        tgt_len = tgt.size(1)
        # Lower-triangular mask: position i may look at positions <= i.
        # Created on tgt's device because it is only ever combined with a mask derived from tgt.
        trg_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), dtype=torch.uint8, device=tgt.device))

        tgt_mask = trg_pad_mask & trg_sub_mask.bool()
        return None, tgt_mask

    def forward(self, *inputs):
        """
        Run backbone (+ optional encoder), decoder and generator.

        :param inputs: ``inputs[0]`` is src (b, c, h, w); ``inputs[1]`` is tgt (b, len_tgt),
            the target or query fed to the decoder; further positional arguments are ignored
        :return: output of ``self.generator`` applied to the decoder output
        """
        src = inputs[0]  # (b, c, h, w)
        tgt = inputs[1]  # (b, len_tgt) target or query, input of decoder

        src_mask, tgt_mask = self.make_mask(src, tgt)
        memory = self.encode(src, src_mask)
        output = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(output)

    def encode(self, src, src_mask):
        """
        Embed the source and, when ``self.with_encoder`` is set, pass it through the encoder.

        :param src: input passed to ``self.src_embed``
        :param src_mask: mask passed to the encoder (``make_mask`` always yields None)
        :return: encoder output, or the plain source embedding when there is no encoder
        """
        embedded = self.src_embed(src)
        if self.with_encoder:  # cnn + encoder + decoder
            return self.encoder(embedded, src_mask)
        return embedded  # cnn + decoder

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """
        Embed the target and run the decoder against the encoded source.

        :param memory: output from ``encode``
        :param src_mask: source mask passed through to the decoder
        :param tgt: raw target input (label of text sequence)
        :param tgt_mask: target mask as produced by ``make_mask``: [b, 1, l_tgt, l_tgt]
        :return: output of ``self.decoder``
        """

        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        return super().__str__() + '\nTrainable parameters: {}'.format(self.model_parameters())

    def model_parameters(self):
        """
        Count the trainable parameters of the model.

        :return: total number of elements over all parameters with ``requires_grad=True``,
            as a plain ``int`` (a 0-dim parameter counts as one element)
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class Generator(nn.Module):
    """
    Generation step: one linear layer mapping model features to per-class scores.

    No softmax is applied here; ``forward`` returns the raw output of the linear layer.
    """

    def __init__(self, hidden_dim, vocab_size):
        """

        :param hidden_dim: dim of model (input feature size of the linear layer)
        :param vocab_size: size of vocabulary (output feature size of the linear layer)
        """
        super().__init__()

        projection = nn.Linear(in_features=hidden_dim, out_features=vocab_size)
        self.fc = projection

    def forward(self, *inputs):
        """

        :param inputs: only the first positional argument is used; any others are ignored
        :return: ``self.fc`` applied to ``inputs[0]``
        """
        features = inputs[0]
        scores = self.fc(features)
        return scores
